import itertools
import logging
import time

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import ParameterGrid
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from experiments.utils import load_dataset_safely, seed_everything


def grid_search_baseline(
    dataset_name: str, max_configs: int = 200, random_state: int = 42
) -> dict:
    """Evaluate a fixed preprocessing + random-forest grid on one dataset.

    A candidate is an imputer (mean / median / most-frequent), an optional
    ``StandardScaler`` and a ``RandomForestClassifier`` whose
    ``n_estimators``, ``max_depth`` and ``min_samples_split`` come from a
    fixed grid, giving 3 * 2 * 36 = 216 candidates. Each candidate is fitted
    on the training split and ranked by accuracy on the validation split.

    Args:
        dataset_name: Name handed to ``load_dataset_safely``.
        max_configs: Upper bound on the number of candidates evaluated.
            Candidates are visited in a fixed order (imputer, then scaler,
            then classifier parameters), so a budget below 216 drops the
            tail of the grid. A budget <= 0 evaluates nothing.
        random_state: Passed to ``seed_everything`` and used as the
            ``random_state`` of every ``RandomForestClassifier``.

    Returns:
        A dict with three keys:
        ``"val_score"`` - best validation accuracy, or ``0.0`` when no
        candidate could be evaluated;
        ``"best_config"`` - ``str()`` of the winning
        ``(imputer, scaler, params)`` tuple, or ``"None"`` when no candidate
        could be evaluated;
        ``"time_sec"`` - seconds spent on the search itself (dataset loading
        is not included).

    Raises:
        RuntimeError: If ``load_dataset_safely`` returns no data; the message
            is the one it reported.
    """
    seed_everything(random_state)
    data, msg = load_dataset_safely(dataset_name)
    if data is None:
        raise RuntimeError(msg)

    logger = logging.getLogger(__name__)

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    start = time.perf_counter()

    imputers = [
        SimpleImputer(strategy=strategy)
        for strategy in ("mean", "median", "most_frequent")
    ]
    scalers = [None, StandardScaler()]  # None means "no scaling step"

    param_grid = {
        "clf__n_estimators": [50, 100, 200],
        "clf__max_depth": [None, 5, 10, 20],
        "clf__min_samples_split": [2, 5, 10],
    }

    def build_pipeline(imputer, scaler):
        steps = [("imputer", imputer)]
        if scaler is not None:
            steps.append(("scaler", scaler))
        steps.append(("clf", RandomForestClassifier(random_state=random_state)))
        return Pipeline(steps)

    X_train, y_train = data["X_train"], data["y_train"]
    X_val, y_val = data["X_val"], data["y_val"]

    best_score = None  # None until at least one candidate evaluates cleanly
    best_cfg = None

    # A single flat loop over the full cross product: the last iterable
    # (classifier parameters) varies fastest, matching a nested
    # imputer -> scaler -> params enumeration, and ``break`` ends the whole
    # search as soon as the budget is spent.
    candidates = itertools.product(imputers, scalers, ParameterGrid(param_grid))
    for tested, (imputer, scaler, params) in enumerate(candidates):
        if tested >= max_configs:
            break

        pipeline = build_pipeline(imputer, scaler)
        pipeline.set_params(**params)
        try:
            pipeline.fit(X_train, y_train)
            score = accuracy_score(y_val, pipeline.predict(X_val))
        except Exception:
            # Best-effort search: one failing candidate (for example an
            # imputation strategy that does not suit the data) must not
            # abort the run, but it is reported instead of being silently
            # scored as 0.0.
            logger.warning(
                "Skipping configuration that failed to evaluate: (%s, %s, %s)",
                imputer,
                scaler,
                params,
                exc_info=True,
            )
            continue

        # Strict comparison keeps the earliest candidate on ties.
        if best_score is None or score > best_score:
            best_score = score
            best_cfg = (imputer, scaler, params)

    duration = time.perf_counter() - start
    return {
        "val_score": 0.0 if best_score is None else best_score,
        "best_config": str(best_cfg),
        "time_sec": duration,
    }


if __name__ == "__main__":
    # Command-line entry point: run the baseline once for the requested
    # dataset and configuration budget, then print the resulting dict.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="iris")
    parser.add_argument("--max_configs", type=int, default=200)
    cli_args = parser.parse_args()

    print(grid_search_baseline(cli_args.dataset, cli_args.max_configs))
